import torch
import math
import argparse
from tqdm import tqdm
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader
from sklearn.metrics import roc_curve, roc_auc_score
import decimal
from transformers import *
import os
import sys
sys.path.insert(0, os.getcwd() + "/Generated-Text-Detector/RobertaDetect/")

from dataset import EncodeEvalData
from detector import RobertaForTextGenClassification


def float_range(start, stop, step):
    """Yield floats from *start* up to (but excluding) *stop* in *step* increments.

    The increment is carried as a ``decimal.Decimal`` so fractional steps
    supplied as strings (e.g. ``"0.1"``) accumulate without binary-float
    drift; each value is converted back to ``float`` on the way out.
    """
    current = start
    increment = decimal.Decimal(step)
    while current < stop:
        yield float(current)
        current += increment

def calculate_eval_metrics(far, pd):
    """Print (and return) detection metrics derived from a ROC curve.

    Args:
        far: false-alarm rates (ROC x-axis), assumed monotonically
            non-decreasing as produced by ``sklearn.metrics.roc_curve``.
        pd: probabilities of detection at the same thresholds (ROC
            y-axis).  NOTE: this parameter name shadows the module-level
            ``pandas as pd`` alias; pandas is not used in this function.

    Returns:
        ``(pd_at_far, pd_at_eer, far_at_eer)``.  The original version only
        printed these values, so returning them is backward-compatible.
    """
    pd_at_far = 0.0
    pd_at_eer = 0.0
    far_at_eer = 0.0

    # pD at 0.1 FAR: take the last operating point with FAR <= 0.1.
    # Guard i == 0 — the old code indexed pd[i-1] == pd[-1] there, silently
    # reporting the LAST curve point instead of "no usable operating point".
    for i, rate in enumerate(far):
        if rate > 0.1:
            pd_at_far = pd[i - 1] if i > 0 else 0.0
            break

    # Equal-error-rate point: first crossing of pD above 1 - FAR; average
    # the two points straddling the crossing.  Same i == 0 guard as above.
    for i in range(len(far)):
        if pd[i] > 1 - far[i]:
            if i > 0:
                pd_at_eer = (pd[i - 1] + pd[i]) / 2
                far_at_eer = (far[i - 1] + far[i]) / 2
            else:
                pd_at_eer = pd[0]
                far_at_eer = far[0]
            break

    print("pD @ 0.1 FAR = %.3f" % (pd_at_far))
    print("pD @ EER = %.3f" % (pd_at_eer))
    print("FAR @ EER = %.3f" % (far_at_eer))

    return pd_at_far, pd_at_eer, far_at_eer

class GeneratedTextDetection:
    """Wrapper around a fine-tuned RoBERTa classifier that scores texts as
    human- vs machine-generated.

    Restores the checkpoint named by ``args.known_model_name`` from
    ``args.check_point`` and exposes :meth:`evaluate` for scoring.
    """

    def __init__(self, args):
        # Fixed seed so repeated evaluations are deterministic.
        torch.manual_seed(1000)

        self.args = args

        # Load the model/tokenizer pair from checkpoints.
        self.init_dict = self._init_detector()

    def _init_detector(self):
        """Build the tokenizer/model and restore weights from the checkpoint.

        Returns:
            dict with keys ``kn_model`` (the restored classifier, moved to
            ``args.device``) and ``kn_tokenizer``.
        """
        model_name = 'roberta-large' if self.args.kn_large else 'roberta-base'
        tokenization_utils.logger.setLevel('ERROR')
        tokenizer = RobertaTokenizer.from_pretrained(model_name)
        model = RobertaForTextGenClassification.from_pretrained(model_name).to(self.args.device)

        # map_location handles both CPU and CUDA targets, so one load path
        # replaces the previous duplicated cpu/cuda branches.
        checkpoint_path = (self.args.check_point + '{}.pt').format(self.args.known_model_name)
        print(checkpoint_path)
        state = torch.load(checkpoint_path, map_location=self.args.device)
        model.load_state_dict(state['model_state_dict'])

        return {"kn_model": model, "kn_tokenizer": tokenizer}

    def evaluate(self, input_text):
        """Score a list of texts with the known-generator detector.

        Args:
            input_text: list of strings to classify.

        Returns:
            dict with per-text predicted class indices (``cls``),
            log-likelihood-ratio scores (``LLR_score``), and the raw
            per-class model outputs (``prob_score``).
        """
        # Encapsulate the inputs.
        eval_dataset = EncodeEvalData(input_text, self.init_dict["kn_tokenizer"], self.args.max_sequence_length)
        eval_loader = DataLoader(eval_dataset)

        # Dictionary collecting all scores and evidence produced by the model.
        results = {"cls": [], "LLR_score": [], "prob_score": {"cls_0": [], "cls_1": []}, "generator": None}

        self.init_dict["kn_model"].eval()

        # Loop-invariant prior term, hoisted out of the batch loop.
        prior_llr = math.log10(self.args.kn_priors[0] / self.args.kn_priors[1])

        with torch.no_grad():
            for texts, masks in eval_loader:
                texts, masks = texts.to(self.args.device), masks.to(self.args.device)

                output_dic = self.init_dict["kn_model"](texts, attention_mask=masks)
                disc_out = output_dic["logits"]

                cls0_prob = disc_out[:, 0].tolist()
                cls1_prob = disc_out[:, 1].tolist()

                results["prob_score"]["cls_0"].extend(cls0_prob)
                results["prob_score"]["cls_1"].extend(cls1_prob)

                # NOTE(review): the log-odds transform below assumes
                # disc_out[:, 1] is a probability in (0, 1); if the head
                # emits raw logits this raises a math domain error — confirm
                # the detector applies softmax/sigmoid before returning.
                results["LLR_score"].extend([math.log10(prob / (1 - prob)) + prior_llr for prob in cls1_prob])

                # Predicted class = argmax over the two output columns.
                _, predicted = torch.max(disc_out, 1)
                results["cls"].extend(predicted.tolist())

        return results

def main():
    """Evaluate a ConDA source->target detector on the target test set.

    Builds the evaluation arguments, restores the ``src``->``tgt``
    checkpoint, scores the real/fake held-out test data, prints
    confusion-matrix statistics plus ROC/EER metrics, and plots the ROC
    curve.
    """
    parser = argparse.ArgumentParser(
        description='Generated Text: Discriminator'
    )

    # Input data and files
    parser.add_argument('--known_model_name', default="", type=str,
                        help='name of the known generator detector model')

    parser.add_argument('--check_point', default="", type=str,
                        help='saved model checkpoint directory')

    # Model parameters
    parser.add_argument('--device', type=str, default=None)
    # NOTE(review): type=list would split a command-line string into single
    # characters; it is harmless here only because the default is always used.
    parser.add_argument('--kn_priors', type=list, default=[0.5, 0.5])
    parser.add_argument('--batch-size', type=int, default=1)
    parser.add_argument('--max-sequence-length', type=int, default=256)
    # NOTE(review): type=bool treats any non-empty string as True; only the
    # False default is relied on in this script.
    parser.add_argument('--kn_large', type=bool, default=False)

    mode = 'conda'
    loss_mode = 'simclr'  ## only simclr supported for now
    transformation = '_syn_rep'  ## only synonym replacement supported for now

    # change as needed:
    src = "fair_wmt19"  # one of ['ctrl', 'fair_wmt19', 'gpt2_xl', 'gpt3', 'grover_mega', 'xlm', 'chatgpt']
    tgt = "ctrl"  # one of ['ctrl', 'fair_wmt19', 'gpt2_xl', 'gpt3', 'grover_mega', 'xlm', 'chatgpt']

    if src == tgt:
        print("Source and target should be different generators.")
        sys.exit(0)

    if mode == 'conda':
        args = parser.parse_args(args=['--check_point=' + os.getcwd() + '/models/',
                                       '--known_model_name=' + src + '_' + tgt + transformation + '_' + loss_mode])
    else:
        raise ValueError("Invalid mode")

    if args.device is None:
        args.device = f'cuda:{0}' if torch.cuda.is_available() else 'cpu'

    predict_prob = []
    y = []

    detector = GeneratedTextDetection(args)

    # Load the target generator's test data (real and fake halves).
    if tgt == 'chatgpt':
        test_data_dir = '/home/abhatt43/Data_for_Testing/ChatGPT/'
        real_test = pd.read_json(test_data_dir + 'chatgpt_real.test.jsonl', lines=True, orient='records')
        fake_test = pd.read_json(test_data_dir + 'chatgpt_fake.test.jsonl', lines=True, orient='records')
    else:
        test_data_dir = '/home/abhatt43/Data_for_Testing//TuringBench/TT_' + tgt + '/'
        real_test = pd.read_json(test_data_dir + 'tb_tt_' + tgt + '_real.test.jsonl', lines=True, orient='records')
        fake_test = pd.read_json(test_data_dir + 'tb_tt_' + tgt + '_fake.test.jsonl', lines=True, orient='records')
    print("---Loaded " + tgt + " test data---")

    # DataFrame.append was removed in pandas 2.0; pd.concat is the
    # supported replacement.
    test_data = pd.concat([fake_test, real_test], ignore_index=True)

    tp = 0
    tn = 0
    fn = 0
    fp = 0
    predicted_ys = []

    for value in tqdm(test_data.itertuples()):

        main_body_text = value.text

        # Skip empty documents: the detector has nothing to score.
        if main_body_text == "":
            continue

        results = detector.evaluate([main_body_text])

        y.append(value.label)

        # LLR of the positive (machine-generated) class is the ranking score.
        predict_prob.append(results["LLR_score"][0])

        predicted = results["cls"][0]
        predicted_ys.append(predicted)

        tp += ((predicted == value.label) & (value.label == 1))
        tn += ((predicted == value.label) & (value.label == 0))
        fn += ((predicted != value.label) & (value.label == 1))
        fp += ((predicted != value.label) & (value.label == 0))

    try:
        recall = float(tp) / (tp + fn)
    except ZeroDivisionError:
        recall = "undefined"

    try:
        precision = float(tp) / (tp + fp)
    except ZeroDivisionError:
        precision = "undefined"

    # F1 is defined only when both precision and recall are numeric and not
    # simultaneously zero (this replaces a bare `except:` that swallowed
    # every error, including genuine bugs).
    if isinstance(precision, float) and isinstance(recall, float) and (precision + recall) > 0:
        f1_score = 2 * precision * recall / (precision + recall)
    else:
        f1_score = "undefined"

    print('TP: %d' % (tp))
    print('TN: %d' % (tn))
    print('FP: %d' % (fp))
    print('FN: %d' % (fn))

    # Guard against an empty test set to avoid a ZeroDivisionError.
    total = tp + tn + fp + fn
    if total:
        print('Accuracy of the discriminator: %d %%' % (
                100 * (tp + tn) / total))
    else:
        print('Accuracy of the discriminator: undefined (no test samples)')

    if type(recall) != str:
        print('Recall of the discriminator: %d %%' % (
            100 * recall))
    else:
        print("Recall: ", recall)

    if type(precision) != str:
        print('Precision of the discriminator: %d %%' % (
            100 * precision))
    else:
        print("Precison: ", precision)

    if type(f1_score) != str:
        print('f1_score of the discriminator: %d %%' % (
            100 * f1_score))
    else:
        print("F1: ", f1_score)

    # ROC AUC over the LLR scores.
    lr_auc = roc_auc_score(y, predict_prob)

    # summarize scores
    print("\n")
    print(" ----- Evaluation Metrics -----")
    print()
    print('Classifier: ROC AUC=%.3f' % (lr_auc))

    # ROC curve and derived pD@FAR / EER metrics.
    lr_fpr, lr_tpr, _ = roc_curve(y, predict_prob)
    calculate_eval_metrics(lr_fpr, lr_tpr)

    # Diagonal y = x reference line (random-chance performance).
    eq_fpr = list(float_range(0, 1, 1 / len(lr_fpr)))
    eq_tpr = list(eq_fpr)

    from matplotlib import pyplot
    # plot the roc curve for the model
    pyplot.plot(lr_fpr, lr_tpr, marker='.', label='ConDA')
    pyplot.plot(eq_fpr, eq_tpr, marker='.', label='Random Chance')

    pyplot.xlabel('Probability of False Alarm')
    pyplot.ylabel('Probability of Detection')
    pyplot.legend()
    pyplot.show()


# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
